/********************************************************************************
*                                                                               *
* netinfo.cpp --       Per-process TCP network performance statistics           *
*                                                                               *
* Copyright (c) Fengren Technology(Guangzhou) Co.LTD. All rights reserved.      *
*                                                                               *
********************************************************************************/

#include "netinfo.h"
#include <iostream>
#include <sstream>
#include <iomanip>
#include <winsock2.h>
#include <ws2ipdef.h>
#include <ws2tcpip.h>
#include <tcpmib.h>
#include <iphlpapi.h>


// Gathers per-connection TCP statistics (state, ports, addresses, byte counters and
// bandwidth estimates) for a single process. Every scanning entry point is static;
// the constructor and destructor exist only so the type remains instantiable.
class NetworkPerformanceScanner
{
public:
    // IPv4 TCP connections owned by processId; when resv is true the endpoint
    // addresses are also rendered as text.
    static std::vector<NetworkPerformanceItem> ScanNetworkPerformanceTCP4(unsigned long processId, bool resv);

    // IPv6 TCP connections owned by processId; when resv is true the endpoint
    // addresses are also rendered as text.
    static std::vector<NetworkPerformanceItem> ScanNetworkPerformanceTCP6(unsigned long processId, bool resv);

    // UDP scanning is not implemented yet:
    //static std::vector<NetworkPerformanceItem> ScanNetworkPerformanceUDP4(unsigned long processId);
    //static std::vector<NetworkPerformanceItem> ScanNetworkPerformanceUDP6(unsigned long processId);

    NetworkPerformanceScanner();
    ~NetworkPerformanceScanner();
};


NetworkPerformanceScanner::NetworkPerformanceScanner()
{

}


// Nothing to release: the scanner owns no resources.
NetworkPerformanceScanner::~NetworkPerformanceScanner() = default;

// TODO - implement TCP v6, UDP
std::vector<NetworkPerformanceItem> NetworkPerformanceScanner::ScanNetworkPerformanceTCP4(unsigned long processId, bool resv)
{
    std::vector<unsigned char> buffer;
    DWORD dwSize = sizeof(MIB_TCPTABLE_OWNER_PID);
    DWORD dwRetValue = 0;
    std::vector<NetworkPerformanceItem> networkPerformanceItems;
    // repeat till buffer is big enough
    do
    {
        buffer.resize(dwSize, 0);
        dwRetValue = GetExtendedTcpTable(buffer.data(), &dwSize, TRUE, AF_INET, TCP_TABLE_OWNER_PID_ALL, 0);
    } while (dwRetValue == ERROR_INSUFFICIENT_BUFFER);

    if (dwRetValue == ERROR_SUCCESS)
    {
        // good case
        // cast to access element values
        PMIB_TCPTABLE_OWNER_PID ptTable = reinterpret_cast<PMIB_TCPTABLE_OWNER_PID>(buffer.data());

        // caution: array starts with index 0, count starts by 1
        for (DWORD i = 0; i < ptTable->dwNumEntries; i++)
        {
            if (ptTable->table[i].dwOwningPid == processId) {
                NetworkPerformanceItem networkPerformanceItem;

                networkPerformanceItem.ProcessId = ptTable->table[i].dwOwningPid;
                networkPerformanceItem.State = ptTable->table[i].dwState;
                networkPerformanceItem.LocalPort = ntohs(ptTable->table[i].dwLocalPort);
                networkPerformanceItem.RemotePort = ntohs(ptTable->table[i].dwRemotePort);
                // networkPerformanceItem.RemoteAddress = remoteStream.str();
                if (resv) {
                    in_addr p;
                    TCHAR buff[160];
                    p.S_un.S_addr = ptTable->table[i].dwLocalAddr;
                    networkPerformanceItem.LocalAddress = InetNtop(AF_INET, &p, buff, 160); // inet_ntoa(p);
                    TCHAR buff2[160];
                    p.S_un.S_addr = ptTable->table[i].dwRemoteAddr;
                    networkPerformanceItem.RemoteAddress = InetNtop(AF_INET, &p, buff2, 160); // inet_ntoa(p);
                }

                MIB_TCPROW row;
                row.dwLocalAddr = ptTable->table[i].dwLocalAddr;
                row.dwLocalPort = ptTable->table[i].dwLocalPort;
                row.dwRemoteAddr = ptTable->table[i].dwRemoteAddr;
                row.dwRemotePort = ptTable->table[i].dwRemotePort;
                row.dwState = ptTable->table[i].dwState;

                networkPerformanceItem.dwConnState = row.dwState;

                // void* processRow = &row;

                if (row.dwRemoteAddr != 0)
                {
                    ULONG rosSize = 0, rodSize = 0;
                    ULONG winStatus;
                    PUCHAR ros = NULL, rod = NULL;
                    rodSize = sizeof(TCP_ESTATS_DATA_ROD_v0);
                    PTCP_ESTATS_DATA_ROD_v0 dataRod = { 0 };
                    //So, ros will not be used for est data and bandwidth
                    if (rosSize != 0) {
                        ros = (PUCHAR)malloc(rosSize);
                        if (ros == NULL) {
                            return networkPerformanceItems;
                        }
                        else
                            memset(ros, 0, rosSize); // zero the buffer
                    }
                    if (rodSize != 0) {
                        rod = (PUCHAR)malloc(rodSize);
                        if (rod == NULL) {
                            if (ros != NULL) {
                                free(ros);
                                ros = NULL;
                            }
                            return networkPerformanceItems;
                        }
                        else
                            memset(rod, 0, rodSize); // zero the buffer
                    }

                    TCP_ESTATS_DATA_RW_v0 DataRw;
                    DataRw.EnableCollection = TRUE;
                    TCP_ESTATS_BANDWIDTH_RW_v0 Bandwidth;
                    Bandwidth.EnableCollectionInbound = TcpBoolOptEnabled;
                    Bandwidth.EnableCollectionOutbound = TcpBoolOptEnabled;

                    winStatus = SetPerTcpConnectionEStats((PMIB_TCPROW)&row, TcpConnectionEstatsData, (BYTE*)&DataRw, 0, sizeof(TCP_ESTATS_DATA_RW_v0), 0);
                    winStatus = SetPerTcpConnectionEStats((PMIB_TCPROW)&row, TcpConnectionEstatsBandwidth, (BYTE*)&Bandwidth, 0, sizeof(TCP_ESTATS_BANDWIDTH_RW_v0), 0);

                    winStatus = GetPerTcpConnectionEStats((PMIB_TCPROW)&row, TcpConnectionEstatsData, NULL, 0, 0, ros, 0, rosSize, rod, 0, rodSize);
                    if (winStatus == NO_ERROR && (row.State == MIB_TCP_STATE_LISTEN || row.State == MIB_TCP_STATE_SYN_SENT || row.State == MIB_TCP_STATE_SYN_RCVD || row.State == MIB_TCP_STATE_ESTAB)) {
                        dataRod = (PTCP_ESTATS_DATA_ROD_v0)rod;
                        //wchar_t buf[512] = { 0 };
                        //wsprintf(buf, L"%I64d -- %I64d (local: %d, remote: %d)\n", dataRod->DataBytesIn, dataRod->DataBytesOut, row.dwLocalPort, row.dwRemotePort);
                        //OutputDebugString(buf);
                        
                        // if the dataBytesIn and dataSegsIn are not equqal, the data may be valid
                        if (!(dataRod->DataBytesIn == dataRod->DataSegsIn && dataRod->DataBytesOut == dataRod->DataSegsOut)) {
                            if ((LONG64)dataRod->DataBytesIn > (LONG64)0 && (LONG64)dataRod->DataBytesOut > (LONG64)0) {
                                networkPerformanceItem.BytesIn = dataRod->DataBytesIn;
                                networkPerformanceItem.BytesOut = dataRod->DataBytesOut;
                            }
                        }
                    }

                    PTCP_ESTATS_BANDWIDTH_ROD_v0 bandwidthRod = { 0 };

                    if (rod != NULL) {
                        free(rod);
                        rod = NULL;
                    }

                    rodSize = sizeof(TCP_ESTATS_BANDWIDTH_ROD_v0);
                    if (rodSize != 0) {
                        rod = (PUCHAR)malloc(rodSize);
                        if (rod == NULL) {
                            if (ros != NULL) {
                                free(ros);
                                ros = NULL;
                            }
                            return networkPerformanceItems;
                        }
                        else
                            memset(rod, 0, rodSize); // zero the buffer
                    }

                    winStatus = GetPerTcpConnectionEStats((PMIB_TCPROW)&row, TcpConnectionEstatsBandwidth, NULL, 0, 0, ros, 0, rosSize, rod, 0, rodSize);
                    if (winStatus == NO_ERROR) {
                        bandwidthRod = (PTCP_ESTATS_BANDWIDTH_ROD_v0)rod;
                        networkPerformanceItem.OutboundBandwidth = bandwidthRod->InboundInstability / 8;
                        networkPerformanceItem.InboundBandwidth = bandwidthRod->OutboundInstability / 8;
                    }
                    if (rod != NULL) {
                        free(rod);
                        rod = NULL;
                    }
                    if (ros != NULL) {
                        free(ros);
                        ros = NULL;
                    }
                }
                networkPerformanceItem.Pass = 0;
                networkPerformanceItems.push_back(networkPerformanceItem);
            }
        }
    }

    return networkPerformanceItems;
}


// TODO - implement TCP v6, UDP
std::vector<NetworkPerformanceItem> NetworkPerformanceScanner::ScanNetworkPerformanceTCP6(unsigned long processId, bool resv)
{
    std::vector<unsigned char> buffer;
    DWORD dwSize = sizeof(MIB_TCP6TABLE_OWNER_PID);
    DWORD dwRetValue = 0;
    std::vector<NetworkPerformanceItem> networkPerformanceItems;

    // repeat till buffer is big enough
    do
    {
        buffer.resize(dwSize, 0);
        dwRetValue = GetExtendedTcpTable(buffer.data(), &dwSize, TRUE, AF_INET6, TCP_TABLE_OWNER_PID_ALL, 0);
    } while (dwRetValue == ERROR_INSUFFICIENT_BUFFER);

    if (dwRetValue == ERROR_SUCCESS)
    {
        // good case

        // cast to access element values
        PMIB_TCP6TABLE_OWNER_PID ptTable = reinterpret_cast<PMIB_TCP6TABLE_OWNER_PID>(buffer.data());
        // caution: array starts with index 0, count starts by 1
        for (DWORD i = 0; i < ptTable->dwNumEntries; i++)
        {
            if (ptTable->table[i].dwOwningPid == processId) {
                NetworkPerformanceItem networkPerformanceItem;
                
                networkPerformanceItem.ProcessId = ptTable->table[i].dwOwningPid;
                networkPerformanceItem.State = ptTable->table[i].dwState;
                networkPerformanceItem.dwConnState = ptTable->table[i].dwState;
                networkPerformanceItem.LocalPort = ntohs(ptTable->table[i].dwLocalPort);
                networkPerformanceItem.RemotePort = ntohs(ptTable->table[i].dwRemotePort);
                if (resv) {
                    TCHAR buff[320];
                    TCHAR buff2[320];
                    networkPerformanceItem.LocalAddress = InetNtop(AF_INET6, ptTable->table[i].ucLocalAddr, buff, 320); // inet_ntoa(p);
                    networkPerformanceItem.RemoteAddress = InetNtop(AF_INET6, ptTable->table[i].ucRemoteAddr, buff2, 320); // inet_ntoa(p);
                }

                MIB_TCP6ROW row;

                //row.LocalAddr = ptTable->table[i].ucLocalAddr;
                memcpy_s(&row.LocalAddr, sizeof(row.LocalAddr), ptTable->table[i].ucLocalAddr, sizeof(ptTable->table[i].ucLocalAddr));
                row.dwLocalScopeId = ptTable->table[i].dwLocalScopeId;
                row.dwLocalPort = ptTable->table[i].dwLocalPort;

                // row.RemoteAddr.u = ptTable->table[i].ucRemoteAddr;
                memcpy_s(&row.RemoteAddr, sizeof(row.RemoteAddr), ptTable->table[i].ucRemoteAddr, sizeof(ptTable->table[i].ucRemoteAddr));
                row.dwRemoteScopeId = ptTable->table[i].dwRemoteScopeId;
                row.dwRemotePort = ptTable->table[i].dwRemotePort;
                row.State = (MIB_TCP_STATE)ptTable->table[i].dwState;

                //void* processRow = &row;

                if (row.State != 0)
                {
                    ULONG rosSize = 0, rodSize = 0;
                    ULONG winStatus;
                    PUCHAR ros = NULL, rod = NULL;
                    rodSize = sizeof(TCP_ESTATS_DATA_ROD_v0);
                    PTCP_ESTATS_DATA_ROD_v0 dataRod = { 0 };

                    if (rosSize != 0) {
                        ros = (PUCHAR)malloc(rosSize);
                        if (ros == NULL) {
                            return networkPerformanceItems;
                        }
                        else
                            memset(ros, 0, rosSize); // zero the buffer
                    }
                    if (rodSize != 0) {
                        rod = (PUCHAR)malloc(rodSize);
                        if (rod == NULL) {
                            if (ros != NULL) {
                                free(ros);
                            }
                            return networkPerformanceItems;
                        }
                        else
                            memset(rod, 0, rodSize); // zero the buffer
                    }

                    TCP_ESTATS_DATA_RW_v0 DataRw;
                    DataRw.EnableCollection = TRUE;

                    TCP_ESTATS_BANDWIDTH_RW_v0 Bandwidth;
                    Bandwidth.EnableCollectionInbound = TcpBoolOptEnabled;
                    Bandwidth.EnableCollectionOutbound = TcpBoolOptEnabled;
                    
                    winStatus = SetPerTcp6ConnectionEStats((PMIB_TCP6ROW)&row, TcpConnectionEstatsData, (BYTE*)&DataRw, 0, sizeof(TCP_ESTATS_DATA_RW_v0), 0);
                    winStatus = SetPerTcp6ConnectionEStats((PMIB_TCP6ROW)&row, TcpConnectionEstatsBandwidth, (BYTE*)&Bandwidth, 0, sizeof(TCP_ESTATS_BANDWIDTH_RW_v0), 0);

                    winStatus = GetPerTcp6ConnectionEStats((PMIB_TCP6ROW)&row, TcpConnectionEstatsData, NULL, 0, 0, ros, 0, rosSize, rod, 0, rodSize);

                    if (winStatus == NO_ERROR && (row.State != MIB_TCP_STATE_CLOSED)) {
                        dataRod = (PTCP_ESTATS_DATA_ROD_v0)rod;
                        if (!(dataRod->DataBytesIn == dataRod->DataSegsIn && dataRod->DataBytesOut == dataRod->DataSegsOut) ) {
                            if ((LONG64)dataRod->DataBytesIn > (LONG64)0 && (LONG64)dataRod->DataBytesOut > (LONG64)0) {
                                networkPerformanceItem.BytesIn = dataRod->DataBytesIn;
                                networkPerformanceItem.BytesOut = dataRod->DataBytesOut;
                            }
                        }
                    }

                    PTCP_ESTATS_BANDWIDTH_ROD_v0 bandwidthRod = { 0 };
                    if (rod != NULL) {
                        free(rod);
                        rod = NULL;
                    }

                    rodSize = sizeof(TCP_ESTATS_BANDWIDTH_ROD_v0);
                    if (rodSize != 0) {
                        rod = (PUCHAR)malloc(rodSize);
                        if (rod == NULL) {
                            if (ros != NULL) {
                                free(ros);
                                ros = NULL;
                            }
                            return networkPerformanceItems;
                        }
                        else
                            memset(rod, 0, rodSize); // zero the buffer
                    }

                    winStatus = GetPerTcp6ConnectionEStats((PMIB_TCP6ROW)&row, TcpConnectionEstatsBandwidth, NULL, 0, 0, ros, 0, rosSize, rod, 0, rodSize);
                    if (winStatus == NO_ERROR) {
                        bandwidthRod = (PTCP_ESTATS_BANDWIDTH_ROD_v0)rod;
                        networkPerformanceItem.OutboundBandwidth = bandwidthRod->OutboundBandwidth;
                        networkPerformanceItem.InboundBandwidth = bandwidthRod->InboundBandwidth;
                    }
                    if (rod != NULL) {
                        free(rod);
                        rod = NULL;
                    }
                    if (ros != NULL) {
                        free(ros);
                        ros = NULL;
                    }
                }
                networkPerformanceItem.Pass = 0;
                networkPerformanceItems.push_back(networkPerformanceItem);
            }
        }
    }

    return networkPerformanceItems;
}

// Accumulates the TCP statistics of process dwProcessId into item.
//
// item        - receives ProcessId; the byte and bandwidth totals of every matching IPv6 and
//               IPv4 connection are ADDED to its existing values (the caller decides the start).
// ptr         - optional; when non-NULL, addresses are resolved and one "addr:port" string is
//               appended per listening socket with a non-zero port ("[addr]:port" for IPv6).
// dwProcessId - process whose connections are inspected.
void GetProcessNetworkPerformance(NetworkPerformanceItem& item, std::vector<std::wstring>* ptr, DWORD dwProcessId) {
    const bool collectListeners = (ptr != NULL);

    // IPv6 is scanned before IPv4, and listener strings are emitted in that same order.
    const std::vector<NetworkPerformanceItem> tcp6Rows = NetworkPerformanceScanner::ScanNetworkPerformanceTCP6(dwProcessId, collectListeners);
    const std::vector<NetworkPerformanceItem> tcp4Rows = NetworkPerformanceScanner::ScanNetworkPerformanceTCP4(dwProcessId, collectListeners);
    item.ProcessId = dwProcessId;

    // Adds one family's rows to the totals; bracketAddress selects the IPv6 "[addr]:port" form.
    auto accumulate = [&](const std::vector<NetworkPerformanceItem>& rows, bool bracketAddress) {
        for (const NetworkPerformanceItem& row : rows) {
            if (row.ProcessId != dwProcessId) {
                continue;
            }
            item.BytesIn += row.BytesIn;
            item.BytesOut += row.BytesOut;
            item.InboundBandwidth += row.InboundBandwidth;
            item.OutboundBandwidth += row.OutboundBandwidth;

            const bool isListener = row.dwConnState == MIB_TCP_STATE_LISTEN && row.LocalPort > 0;
            if (!collectListeners || !isListener) {
                continue;
            }

            std::wstringstream portText;
            portText << row.LocalPort;

            std::wstring endpoint;
            if (bracketAddress) {
                endpoint.append(L"[").append(row.LocalAddress).append(L"]:");
            } else {
                endpoint.append(row.LocalAddress).append(L":");
            }
            endpoint.append(portText.str());
            ptr->push_back(endpoint);
        }
    };

    accumulate(tcp6Rows, true);
    accumulate(tcp4Rows, false);
}
